import io
import os

import streamlit as st
from PIL import Image
from utils import *

from model import ArtNet

def preprocess_image(image_path):
    """Load an image and turn it into a single-item batch for the model.

    Args:
        image_path: A filesystem path or a binary file-like object -- anything
            ``PIL.Image.open`` accepts.

    Returns:
        The result of ``transformer`` applied to the image, with a leading
        batch dimension of size 1 added via ``unsqueeze(0)``.
    """
    # Context manager closes the underlying file promptly instead of leaving
    # the handle open until garbage collection.
    with Image.open(image_path) as image:
        # The uploader accepts PNGs, which may be RGBA, palette or greyscale;
        # convert() also forces the pixel data to load before the file closes.
        # NOTE(review): assumes `transformer` / ArtNet expect 3-channel RGB
        # input (torchvision's default ImageFolder loader also converts to
        # RGB) — confirm against the training pipeline.
        rgb_image = image.convert("RGB")
    input_tensor = transformer(rgb_image)
    return input_tensor.unsqueeze(0)  # Add a batch dimension

# NOTE(review): 11 presumably equals the number of entries in `classes`
# (from utils) — confirm if classes are added or removed.
model = ArtNet(11)
# The checkpoint location can be overridden with the PLAKSHA_VISTA_MODEL_PATH
# environment variable; the default is the original developer-machine path.
model_path = os.environ.get(
    "PLAKSHA_VISTA_MODEL_PATH",
    "/home/shivam-wiz/Downloads/MLPR___/Trial/best_checkpoint.model",
)
# map_location="cpu": the model and its inputs stay on the CPU in this app, and
# this lets a checkpoint saved from a GPU session load on a host without CUDA.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
# NOTE(review): Streamlit re-executes the script on every interaction, so this
# checkpoint is reloaded each rerun; consider st.cache_resource if it is slow.
model.eval()

def _render_sidebar():
    """Show the picture-upload instructions in the sidebar."""
    st.sidebar.markdown("### Instructions to Upload a Picture")
    st.sidebar.markdown(
        """
        1. Click the 'Browse Files' button below.
        2. Upload an image.
        3. View the details of the uploaded image on the right side.
        """
    )


def _classify_upload(uploaded_file):
    """Classify an uploaded image and return ``(title, description)``.

    The upload is decoded from memory instead of being written to a shared
    file in the working directory, so concurrent sessions cannot overwrite
    (and then classify) each other's image.
    """
    # getvalue() returns the full buffer regardless of the current read
    # position, so it is unaffected by whatever st.image read beforehand.
    image_stream = io.BytesIO(uploaded_file.getvalue())
    with torch.no_grad():
        input_batch = preprocess_image(image_stream)
        output = model(input_batch)
        probabilities = torch.nn.functional.softmax(output[0], dim=0)
        predicted_class = torch.argmax(probabilities).item()
    class_name = classes[predicted_class]
    # Both lookup tables (from utils) are keyed by the class name:
    # class_description gives the displayed title, data_description the text.
    return class_description[class_name], data_description[class_name]


def main():
    """Render the Plaksha Vista page and classify an uploaded campus photo.

    The left column holds the upload widget and a preview; the right column
    shows the predicted title and description once an image is uploaded.
    """
    # Page configuration (Streamlit expects this before other st.* calls).
    st.set_page_config(
        page_title="Plaksha Vista",
        page_icon="🏫",
        layout="wide"
    )

    # Main title and subheading
    st.title("Plaksha Vista")
    st.markdown("## Your Guide to the Plaksha Campus")

    _render_sidebar()

    # Left: upload + preview; middle column is an empty spacer; right: output.
    col1, _, col2 = st.columns([3, 1, 6])

    uploaded_file = col1.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
    if uploaded_file:
        col1.image(uploaded_file, caption="Uploaded Image", use_column_width=True)

    with col2:
        st.subheader("Image Details")
        if uploaded_file:
            image_title, image_description = _classify_upload(uploaded_file)
            st.markdown(f"<h4>Title: {image_title}</h4>", unsafe_allow_html=True)
            st.markdown(f"<h5>Description:\n{image_description}</h5>", unsafe_allow_html=True)

if __name__ == "__main__":
    # Script entry point: build and render the page.
    main()